{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 复合路径"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tvm\n",
    "from tvm import relax\n",
    "from tvm.relax.backend.cuda.cublas import partition_for_cublas\n",
    "from tvm.relax.backend.cuda.cutlass import partition_for_cutlass\n",
    "from tvm.relax.dpl.pattern import (\n",
    "    is_op,\n",
    "    is_tuple_get_item,\n",
    "    make_fused_bias_activation_pattern,\n",
    "    wildcard,\n",
    ")\n",
    "from tvm.relax.transform import PatternCheckContext\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建简单的计算图，只包含一个乘法算子。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = relax.Var(\"x\", relax.TensorStructInfo([10, 10], \"float32\"))\n",
    "y = relax.Var(\"y\", relax.TensorStructInfo([10, 10], \"float32\"))\n",
    "bb = relax.BlockBuilder()\n",
    "with bb.function(\"main\", [x, y]):\n",
    "    with bb.dataflow():\n",
    "        lv0 = bb.emit(relax.op.multiply(x, y))\n",
    "        gv = bb.emit_output(lv0)\n",
    "    bb.emit_func_output(gv)\n",
    "mod = bb.get()\n",
    "mod = relax.transform.CanonicalizeBindings()(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(x, y)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当前有两种 BYOC (自带代码生成)的实现路径：\n",
    "\n",
    "- 路径1：`[FuseOpsByPattern(patterns, annotate_codegen=True), RunCodegen()]`\n",
    "- 路径2：`[FuseOpsByPattern(patterns, annotate_codegen=False), MergeCompositeFunctions(), RunCodegen()]`\n",
    "为了保持一致性，两条路径都应具有与 `RunCodegen()` 相同的接口。\n",
    "\n",
    "关键点说明：\n",
    "\n",
    "- 两种路径最终功能等效但实现方式不同\n",
    "- 路径1直接生成带代码生成属性的融合函数\n",
    "- 路径2先进行基础融合，再通过 `MergeCompositeFunctions` 添加代码生成属性\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_multiply_cutlass</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;cutlass&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">local_func</span>(x_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;cutlass.multiply&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(x_1, y_1)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "        output: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> local_func(x, y)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> output\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_multiply_cutlass(x, y)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "patterns = [(\"cutlass.multiply\", is_op(\"relax.multiply\")(wildcard(), wildcard()))]\n",
    "mod1 = relax.transform.FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(\n",
    "    mod\n",
    ")\n",
    "assert tvm.relax.analysis.well_formed(mod1)\n",
    "mod1.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "路径二："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">fused_relax_multiply1_cutlass</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Codegen&quot;</span>: <span style=\"color: #BA2121\">&quot;cutlass&quot;</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "        \n",
       "        <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "        <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">gv</span>(x_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;Composite&quot;</span>: <span style=\"color: #BA2121\">&quot;cutlass.multiply&quot;</span>})\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "                gv_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(x_1, y_1)\n",
       "                R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv_1)\n",
       "            <span style=\"color: #008000; font-weight: bold\">return</span> gv_1\n",
       "\n",
       "        gv_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> gv(x, y)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv_1\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>fused_relax_multiply1_cutlass(x, y)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod2 = relax.transform.FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=False)(\n",
    "    mod\n",
    ")\n",
    "mod2 = relax.transform.MergeCompositeFunctions()(mod2)\n",
    "assert tvm.relax.analysis.well_formed(mod2)\n",
    "mod2.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
